import time
from os.path import join

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

from .base import Trainer


class CoresTrainer(Trainer):
    """Trainer implementing CORES-style sample sieving with a class-prior update.

    Each epoch trains on ``self.train_loader`` with ``self.criterion``. Per class,
    it counts the samples whose per-sample flag ``loss_v`` returned by the
    criterion is 0; those samples are treated as noisy. After every epoch it
    re-estimates ``self.cur_noise_prior`` from the initial prior and those counts.
    """

    def __init__(
        self,
        config,
        model,
        logger,
        train_set,
        test_set,
        criterion,
        optimizer,
        num_classes,
        init_noise_prior,
        scheduler=None,
        val_set=None,
    ):
        """Set up the base ``Trainer`` plus the CORES-specific state.

        Args:
            config, model, logger, train_set, test_set, criterion, optimizer,
            scheduler, val_set: forwarded unchanged to ``Trainer.__init__``.
            num_classes: number of label classes.
            init_noise_prior: initial per-class prior (NumPy array or tensor
                with ``num_classes`` entries). Presumably the class frequencies
                of the noisy labels -- confirm against the caller.
        """
        super().__init__(
            config,
            model,
            logger,
            train_set,
            test_set,
            criterion,
            optimizer,
            scheduler,
            val_set,
        )
        self.num_training_samples = len(train_set)
        self.num_classes = num_classes
        # Every epoch's estimate is derived from the *initial* prior (not the
        # previous estimate), so keep a flat host-side float copy of it.
        self.init_noise_prior = (
            self._as_numpy(init_noise_prior).astype(np.float64).reshape(-1)
        )
        self.cur_noise_prior = init_noise_prior

    @staticmethod
    def _as_numpy(value):
        """Return *value* as a NumPy array, copying tensors to host memory first."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def run(self):
        """Train for ``self.epoch`` epochs, updating the noise prior after each.

        Returns:
            The last epoch's mean per-batch top-1 precision, as reported by
            ``self.accuracy``. Returns ``None`` when ``self.epoch`` is 0.
        """
        train_acc = None
        for cur_epoch in range(self.epoch):
            train_acc, class_size_noisy = self._train_one_epoch(cur_epoch)
            # The prior must be refreshed every epoch, not once after training.
            self._update_noise_prior(class_size_noisy)
        return train_acc

    def _train_one_epoch(self, cur_epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            cur_epoch: zero-based epoch index, forwarded to ``self.criterion``.

        Returns:
            A tuple ``(train_acc, class_size_noisy)``:

            - ``train_acc``: mean per-batch top-1 precision, or ``nan`` when
              the loader yields no batches.
            - ``class_size_noisy``: int64 array of length ``num_classes``
              counting, per label, the samples whose ``loss_v`` flag was 0.
        """
        self.model.train()
        train_total = 0
        train_correct = 0
        class_size_noisy = np.zeros(self.num_classes, dtype=np.int64)
        for data in self.train_loader:
            inputs, labels, _attribute, index = self.prepare_data(data)
            logits = self.model(inputs)
            prec, _ = self.accuracy(logits, labels, topk=(1, 5))
            train_total += 1
            train_correct += prec
            # NOTE(review): the current noise prior is not passed to the
            # criterion, although the CORES loss normally consumes it --
            # confirm the criterion obtains it some other way.
            loss, loss_v = self.criterion(cur_epoch, logits, labels)

            # Count, per label, the samples flagged 0 by the criterion.
            # Only the first len(index) entries are used, as before.
            batch_size = len(index)
            is_noisy = self._as_numpy(loss_v)[:batch_size] == 0
            batch_labels = self._as_numpy(labels)[:batch_size]
            class_size_noisy += np.bincount(
                batch_labels[is_noisy].astype(np.int64),
                minlength=self.num_classes,
            )

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        train_acc = (
            float(train_correct) / float(train_total) if train_total else float("nan")
        )
        return train_acc, class_size_noisy

    def _update_noise_prior(self, class_size_noisy):
        """Re-estimate ``self.cur_noise_prior`` from the initial prior.

        Computes ``init_prior * N - noisy_counts`` and normalises it to sum to
        1. ``N`` is ``len(train_set)``.

        Args:
            class_size_noisy: per-class counts of samples flagged as noisy
                during the epoch that just finished.
        """
        prior = self.init_noise_prior * self.num_training_samples - np.asarray(
            class_size_noisy, dtype=np.float64
        )
        self.cur_noise_prior = prior / prior.sum()
